Linear Attention: Transformers are RNNs
Linear Attention
$ Attention(Q, K, V) = sortmax(\frac{QK^T}{\sqrt{d_{key}}})V
$ Attention(Q, K, V)_i = \frac{\sum_{j=1}^n\exp(q_i^Tk_j)\cdot v_j}{\sum_{j=1}^n\exp(q_i^Tk_j)}
O(n^2)の部分をどうにかしたい
O(n)に落としたい → Linear Attention
とにかく類似度の計算ができれば良いので, 別の類似度計算に置き換えたい
simでまとめると
$ sim(q, k)=exp(\frac{q^Tk}{\sqrt{d_{key}}})
$ Attention(Q, K, V)_i = \frac{\sum_{j=1}^nsim(q_i, k_j)\cdot v_j}{\sum_{j=1}^nsim(q_i, k_j)}
q_iとk_jに依存しているので, 乗法に分離できると嬉しい
前処理O(n)でO(nd)に
$ \phi(x) に$ elu(x) + 1を採用 (ELU) 結構速度・精度ともに良いらしい